import os
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
import torchvision

import itertools
import functools

from math import sqrt

import numpy as np
import heapq
from functools import partial
from collections import namedtuple

import smooth_dp_utils
import utils
import data_utils

from tqdm import tqdm
import time




class CombRenset18(nn.Module):
    """ResNet-18 stem plus its first residual stage, max-pooled to a square score grid.

    The stock ResNet-18 is built with ``num_classes=out_features`` and its first
    convolution is replaced by one accepting ``in_channels`` input channels.
    Only ``conv1``, ``bn1``, ``relu``, ``maxpool`` and ``layer1`` are used in
    ``forward``; the remaining ResNet layers (``layer2``-``layer4``, ``avgpool``,
    ``fc``) are constructed but never called.

    Args:
        out_features: number of output cells; the pooled grid is
            ``int(sqrt(out_features))`` cells on each side.
        in_channels: channel count of the input images.
    """

    def __init__(self, out_features, in_channels):
        super().__init__()
        backbone = torchvision.models.resnet18(pretrained=False, num_classes=out_features)
        # Remove, then re-add, the stem conv so it accepts `in_channels` channels.
        # NOTE: the del + reassign moves `conv1` to the end of the submodule
        # registry, which affects state_dict key order -- keep it as is.
        del backbone.conv1
        backbone.conv1 = nn.Conv2d(in_channels, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.resnet_model = backbone

        side = int(sqrt(out_features))
        self.pool = nn.AdaptiveMaxPool2d((side, side))

    def forward(self, x):
        """Map an image batch to a (side x side) grid of scores per sample.

        Presumably ``x`` is NCHW with ``in_channels`` channels -- TODO confirm
        with callers. The channel axis (dim 1) is averaged after pooling.
        """
        net = self.resnet_model
        for stage in (net.conv1, net.bn1, net.relu, net.maxpool, net.layer1):
            x = stage(x)
        pooled = self.pool(x)
        return pooled.mean(dim=1)
    
    
def neighbours_8(x, y, x_max, y_max):
    """Yield the in-bounds 8-connected neighbours of cell ``(x, y)``.

    Offsets are visited with ``dx`` in (-1, 0, 1) as the outer loop and ``dy``
    in (-1, 0, 1) as the inner loop. The cell itself is skipped, as are
    coordinates outside ``[0, x_max) x [0, y_max)``.
    """
    for dx in (-1, 0, 1):
        for dy in (-1, 0, 1):
            if dx == 0 and dy == 0:
                continue
            nx, ny = x + dx, y + dy
            if 0 <= nx < x_max and 0 <= ny < y_max:
                yield nx, ny


def neighbours_4(x, y, x_max, y_max):
    """Yield the in-bounds 4-connected neighbours of cell ``(x, y)``.

    Order: down-x (+1, 0), up-y (0, +1), down-y (0, -1), up-x (-1, 0);
    coordinates outside ``[0, x_max) x [0, y_max)`` are dropped.
    """
    offsets = ((1, 0), (0, 1), (0, -1), (-1, 0))
    candidates = ((x + dx, y + dy) for dx, dy in offsets)
    yield from (
        (cx, cy) for cx, cy in candidates if 0 <= cx < x_max and 0 <= cy < y_max
    )


def get_neighbourhood_func(neighbourhood_fn):
    """Return the neighbour generator named by ``neighbourhood_fn``.

    Accepted names are ``"4-grid"`` (``neighbours_4``) and ``"8-grid"``
    (``neighbours_8``); anything else raises a plain ``Exception``.
    """
    if neighbourhood_fn == "4-grid":
        return neighbours_4
    if neighbourhood_fn == "8-grid":
        return neighbours_8
    raise Exception(f"neighbourhood_fn of {neighbourhood_fn} not possible")
    
    
               
DijkstraOutput = namedtuple("DijkstraOutput", ["shortest_path", "is_unique", "transitions"])


def dijkstra(matrix, neighbourhood_fn="8-grid", request_transitions=False):
    """Find a minimum-cost path from the top-left to the bottom-right grid cell.

    Costs are node weights: stepping onto cell ``(x, y)`` costs
    ``matrix[x][y]``, and the start cell's own weight is included. As with any
    Dijkstra, weights are expected to be non-negative.

    Args:
        matrix: 2-D array of per-cell costs.
        neighbourhood_fn: ``"4-grid"``, ``"8-grid"``, or a callable
            ``fn(x, y, x_max, y_max)`` yielding neighbour coordinates (same
            signature as ``neighbours_4`` / ``neighbours_8``).
        request_transitions: if True, also return the predecessor map.

    Returns:
        DijkstraOutput with:
          * shortest_path -- array shaped like ``matrix`` with 1 on every cell
            of the recovered path (start and end included), 0 elsewhere;
          * is_unique -- True iff exactly one minimum-cost path reaches the
            bottom-right cell;
          * transitions -- dict mapping each reached non-start cell to its
            predecessor, or None unless ``request_transitions`` is True.

    Raises:
        KeyError: if the bottom-right cell is never reached (e.g. all paths
            cost at least the 1e10 "unvisited" sentinel).
    """
    x_max, y_max = matrix.shape
    if callable(neighbourhood_fn):
        base_neighbours = neighbourhood_fn
    else:
        base_neighbours = get_neighbourhood_func(neighbourhood_fn)
    neighbors_func = partial(base_neighbours, x_max=x_max, y_max=y_max)

    costs = np.full_like(matrix, 1.0e10)  # 1e10 acts as "not reached yet"
    costs[0][0] = matrix[0][0]
    # Not an exact path count: only used to test `== 1` (uniqueness).
    num_path = np.zeros_like(matrix)
    num_path[0][0] = 1
    priority_queue = [(matrix[0][0], (0, 0))]
    certain = set()
    transitions = dict()

    while priority_queue:
        _, (cur_x, cur_y) = heapq.heappop(priority_queue)
        if (cur_x, cur_y) in certain:
            # Stale heap entry for an already-settled cell: skip it, otherwise
            # its neighbours are re-relaxed and ties get double-counted.
            continue

        for x, y in neighbors_func(cur_x, cur_y):
            if (x, y) in certain:
                continue
            candidate = matrix[x][y] + costs[cur_x][cur_y]
            if candidate < costs[x][y]:
                costs[x][y] = candidate
                heapq.heappush(priority_queue, (costs[x][y], (x, y)))
                transitions[(x, y)] = (cur_x, cur_y)
                num_path[x, y] = num_path[cur_x, cur_y]
            elif candidate == costs[x][y]:
                # An equally cheap alternative route: path is no longer unique.
                num_path[x, y] += 1

        certain.add((cur_x, cur_y))

    # Walk predecessors back from the bottom-right cell to the start.
    cur_x, cur_y = x_max - 1, y_max - 1
    on_path = np.zeros_like(matrix)
    on_path[-1][-1] = 1
    while (cur_x, cur_y) != (0, 0):
        cur_x, cur_y = transitions[(cur_x, cur_y)]
        on_path[cur_x, cur_y] = 1.0

    is_unique = num_path[-1, -1] == 1

    return DijkstraOutput(
        shortest_path=on_path,
        is_unique=is_unique,
        transitions=transitions if request_transitions else None,
    )
    
    
    
def nodes_to_M_batch(nodes):
    """Build batched dense edge-cost matrices for an 8-connected N x N grid.

    Cells are numbered row-major, ``idx = row * N + col``. For every ordered
    pair of distinct, 8-adjacent cells ``(i, j)``, ``M_batch[b, i, j]`` is the
    cost of entering cell ``j``, i.e. ``nodes[b, j // N, j % N]``. Every other
    entry, including the diagonal, is the sentinel ``10000.``.

    Args:
        nodes: tensor of per-cell costs; indexed as (batch, row, col) with
            ``N = nodes.shape[1]`` (a square grid is presumably expected).

    Returns:
        Tensor of shape ``(batch_size, N**2, N**2)`` on ``nodes.device``.
    """
    batch_size, N, _ = nodes.shape
    n_cells = N * N
    device = nodes.device
    M_batch = torch.full((batch_size, n_cells, n_cells), 10000., device=device)

    cell = torch.arange(n_cells, device=device)
    row = torch.div(cell, N, rounding_mode='trunc')
    col = cell % N
    # Two cells are adjacent iff their rows and their columns each differ by at
    # most one. Comparing row/col directly avoids wrap-around artefacts of the
    # flat-index-difference test, which wrongly dropped edges for N <= 2.
    adjacent = (
        ((row[:, None] - row[None, :]).abs() <= 1)
        & ((col[:, None] - col[None, :]).abs() <= 1)
        & (cell[:, None] != cell[None, :])
    )

    src, dst = adjacent.nonzero(as_tuple=True)
    # Edge i -> j carries the cost of the destination cell j.
    M_batch[:, src, dst] = nodes[:, row[dst], col[dst]]
    return M_batch



def get_M_indices(N):
    """Return all directed edges of an 8-connected N x N grid.

    Cells are numbered row-major (``row * N + col``). Each output row is
    ``[i, j]`` for a pair of distinct cells whose rows and columns each differ
    by at most one. Rows are in row-major order of ``(i, j)``.

    Args:
        N: side length of the grid.

    Returns:
        ``torch.long`` tensor of shape ``(n_edges, 2)`` on the CPU.
    """
    cell = torch.arange(N * N)
    row = torch.div(cell, N, rounding_mode='trunc')
    col = cell % N
    # Row/col comparison is correct for every N (the previous flat-index
    # difference test dropped horizontal/diagonal edges when N == 2).
    adjacent = (
        ((row[:, None] - row[None, :]).abs() <= 1)
        & ((col[:, None] - col[None, :]).abs() <= 1)
        & (cell[:, None] != cell[None, :])
    )
    # nonzero() already yields long indices of shape (n_edges, 2) in row-major order.
    return adjacent.nonzero()


def get_path_nodes(M_indices, grid, st=-1, en=-1):
    """Find a fewest-hops path between two cells of ``grid`` that equal 1.

    Cells are identified by flat row-major indices ``row * N + col`` with
    ``N = grid.shape[0]`` (the grid is presumably square -- TODO confirm).
    Only edges listed in ``M_indices`` whose both endpoints hold the value 1 in
    ``grid`` are traversable; each row ``(a, b)`` is a directed edge a -> b.

    Args:
        M_indices: integer tensor of shape (E, 2) of candidate edges, e.g. the
            output of ``get_M_indices(N)``.
        grid: 2-D tensor or array; cells equal to 1 are walkable.
        st: start cell index; -1 selects the top-left cell (0).
        en: end cell index; -1 selects the bottom-right cell (N**2 - 1).

    Returns:
        List of flat cell indices from ``st`` to ``en`` (inclusive) found by
        breadth-first search, or None when ``en`` is unreachable.
    """
    from collections import deque

    N = grid.shape[0]
    # Resolve each default on its own; a lone -1 used to reach the BFS
    # unchanged and could never be matched (or raised KeyError as a start).
    if st == -1:
        st = 0
    if en == -1:
        en = N**2 - 1

    grid_t = torch.as_tensor(grid).detach().cpu()
    edges = torch.as_tensor(M_indices).detach().cpu().long().reshape(-1, 2)
    src, dst = edges[:, 0], edges[:, 1]

    def _walkable(cells):
        """Boolean mask telling which flat cell indices hold 1 in the grid."""
        return grid_t[torch.div(cells, N, rounding_mode='trunc'), cells % N] == 1

    # Vectorised edge filter; an empty selection is fine (previously
    # torch.stack([]) raised when no edge survived).
    keep = _walkable(src) & _walkable(dst)

    graph = {i: [] for i in range(N**2)}
    for a, b in zip(src[keep].tolist(), dst[keep].tolist()):
        graph[a].append(b)

    # BFS with parent pointers; the start is marked visited up front so it is
    # never re-enqueued.
    parent = {st: None}
    queue = deque([st])
    while queue:
        node = queue.popleft()
        if node == en:
            path = []
            while node is not None:
                path.append(node)
                node = parent[node]
            return path[::-1]
        for neighbour in graph[node]:
            if neighbour not in parent:
                parent[neighbour] = node
                queue.append(neighbour)
    return None